#    Copyright 2012, Robert Baruch (robert.c.baruch@gmail.com)
#
#    This file is part of PYZ.
#
#    PYZ is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    PYZ is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with PYZ.  If not, see <http://www.gnu.org/licenses/>.


'''
Created on May 31, 2012

@author: robertbaruch
'''

import argparse
import os
import struct
import sys
import random
import time
import itertools

# The three z-character alphabets used for zchars 6-31:
#   index 0 (A0) lower case, index 1 (A1) upper case, index 2 (A2) punctuation.
# In A2, position 0 (zchar 6) introduces a 10-bit ZSCII escape when decoding,
# so the ' ' stored there is only a placeholder; position 1 (zchar 7) is newline.
zchars_alphabet = [
    'abcdefghijklmnopqrstuvwxyz',
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
    ' \n0123456789.,!?_#\'"/\\-:()',
]

# Byte-string copies of the alphabets above, for membership tests and
# index lookups on ZSCII byte values.
bzchars_alphabet = [alphabet.encode('ascii') for alphabet in zchars_alphabet]

def strtozstr(string):
    '''Convert a Python string to a z-encoded byte string.

    Only characters that appear in one of the three alphabets in
    zchars_alphabet are kept; every other character (including any
    non-Latin-1 character) is dropped before encoding.'''
    # Test membership on str rather than on bytes: `ord(c) in some_bytes`
    # raises ValueError for code points above 255.
    encodable = ''.join(zchars_alphabet)
    return zsciitozstr(bytes(ord(c) for c in string if c in encodable))

def zsciitozstr(zscii, length = 0):
    '''Convert a byte string of ZSCII codes to a z-encoded byte string.

    Space is zchar 0. Lower-case letters use alphabet A0 directly, upper-case
    letters are shifted with 4 (A1), and characters of the A2 table with 5.
    Any other non-zero code below 1024 is written as a 10-bit ZSCII escape:
    5, 6, then its top and bottom 5 bits. The escape matches the decoding in
    Memory.zstrtozchars.

    If length (in words) is non-zero, the result is padded with shift
    characters or truncated so it is exactly that many words long. The first
    byte of the last word always has the end bit (0x80) set. An empty input
    encodes as a single padding word, so the result is never empty.'''
    zchars = []
    for z in zscii:
        if z == ord(' '):
            zchars.append(0)
        elif ord('a') <= z <= ord('z'):
            zchars.append(z - ord('a') + 6)
        elif ord('A') <= z <= ord('Z'):
            zchars += (4, z - ord('A') + 6)
        elif z in bzchars_alphabet[2]:
            zchars += (5, bzchars_alphabet[2].index(bytes([z])) + 6)
        elif 0 < z < 0x400:
            # Not in any alphabet: use the 10-bit ZSCII escape.
            zchars += (5, 6, z >> 5, z & 0x1F)

    # Pad with shift characters (5) to a whole number of 3-zchar words. An
    # empty string still needs one word to carry the end bit.
    while not zchars or len(zchars) % 3 != 0:
        zchars.append(5)

    zstr = []
    for i in range(0, len(zchars), 3):
        first, second, third = zchars[i:i + 3]
        zstr.append(first << 2 | second >> 3)
        zstr.append((second << 5 | third) & 0xFF)

    if length != 0:
        nbytes = length << 1
        if len(zstr) > nbytes:
            zstr = zstr[:nbytes]
        else:
            # 0x14A5 is a word of three shift (5) characters.
            zstr += [0x14, 0xA5] * ((nbytes - len(zstr)) >> 1)

    zstr[-2] |= 0x80
    return bytes(zstr)
    


def word_to_signed(w):
    '''Reinterpret a 16-bit unsigned word as a two's-complement signed integer.'''
    if w >= 0x8000:
        return w - 0x10000
    return w

class IllegalInstructionError(Exception):
    '''Raised for an opcode that cannot be executed.

    loc    -- address of the offending instruction
    opcode -- the opcode value, shown in hex in the message'''
    def __init__(self, loc, opcode):
        # Hand the arguments to Exception so .args, repr() and pickling
        # (which re-calls the class with .args) all work.
        super().__init__(loc, opcode)
        self.loc = loc
        self.opcode = opcode

    def __str__(self):
        return "Illegal instruction {:02X} at location {:04X}".format(self.opcode, self.loc)

class Memory:
    def __init__(self, zfilesize = 0, zfile = None, zmem = None, ostr = sys.stdout, istr = sys.stdin, sstr = None):
        if zfile != None:
            self.memory = bytearray(zfilesize)
            zfile.readinto(self.memory)
        if zmem != None:
            self.memory = bytearray(zmem)
        self.header = Header(self.memory)
        self.output_stream = ostr
        self.input_stream = istr
        self.status_thing = self if sstr == None else sstr
        
    default_translation = ("\u00E4\u00F6\u00FC\u00C4\u00D6\u00DC\u00DF\u00BB"
                           "\u00AB\u00EB\u00EF\u00FF\u00CB\u00CF\u00E1\u00E9"
                           "\u00ED\u00F3\u00FA\u00FD\u00C1\u00C9\u00CD\u00D3"
                           "\u00DA\u00DD\u00E0\u00E8\u00EC\u00F2\u00F9\u00C0"
                           "\u00C8\u00CC\u00D2\u00D9\u00E2\u00EA\u00EE\u00F4"
                           "\u00FB\u00C2\u00CA\u00CE\u00D4\u00DB\u00E5\u00C5"
                           "\u00F8\u00D8\u00E3\u00F1\u00F5\u00C3\u00D1\u00D5"
                           "\u00E6\u00C6\u00E7\u00C7\u00FE\u00F0\u00DE\u00D0"
                           "\u00A3\u0153\u0152\u00A1\u00BF")
    
    def set_output_stream(self, ostr):
        self.output_stream = ostr
        
    def print_char(self, c):
        print(c, end='', file=self.output_stream)
        
    def print_num(self, num):
        print(str(num), end='', file=self.output_stream)
        
    def print_str(self, s):
        print(s, end='', file=self.output_stream)

    def zchar_to_char(self, z):
        if z <= 126:
            return chr(z)
        elif z == 13:
            return '\n'
        elif 155 <= z <= 223:
            return Memory.default_translation[z-155]
        else:
            return ''
        
    def char_to_zchar(self, c):
        if c == '\n':
            return 13
        if ord(c) <= 126:
            return ord(c)
        default_index = Memory.default_translation.find(c)
        return 0 if default_index == -1 else default_index + 155
        
    def print_zscii(self, z):
        self.print_char(self.zchar_to_char(z))
            
    def print_zstr_at(self, loc):
        size = self.zchar_size(loc)
        for z in self.zstrtozchars(self.memory[loc : loc + size]):
            self.print_zscii(z)
            
    def get_zchars_at(self, loc):
        return self.zstrtozchars(self.memory[loc : loc + self.zchar_size(loc)])
            
    def print_nl(self):
        self.print_char('\n')
        
    def zchar_size(self, loc):
        '''Determine the size in bytes of a z-encoded string at the given location'''
        if loc == 0:
            raise RuntimeError("zchar_size: Attempt to print string at address 0")
        end = loc
        while not self.memory[end] & 0x80:
            end += 2
        return end + 2 - loc
    
    def zstrtozchars(self, zstr):
        '''Convert a z-encoded byte string to a list of zchars'''
        zchars = []
        for i in range(0, len(zstr), 2):
            zchars.append(zstr[i] >> 2 & 0x1F)
            zchars.append( (zstr[i] & 0x03) << 3 | zstr[i+1] >> 5)
            zchars.append( zstr[i+1] & 0x1F)
        s = []
        alphabet = 0
        abbreviation = 0
        zchar10count = 0
        zchar10 = 0
                
        for z in zchars:
            
            if zchar10count > 0:
                zchar10 = (zchar10 << 5) | z
                zchar10count -= 1
                if zchar10count == 0:
                    s.append(zchar10)
                
            elif abbreviation != 0:
                abbreviation_loc = self.get_abbreviation_loc(32 * (abbreviation-1) + z)
                size = self.zchar_size(abbreviation_loc)
                s += self.zstrtozchars(self.memory[abbreviation_loc : abbreviation_loc + size])
                abbreviation = 0
                
            elif alphabet == 2 and z == 6:
                zchar10count = 2
                zchar10 = 0
                alphabet = 0
                
            elif z >= 6:
                s.append(ord(zchars_alphabet[alphabet][z-6]))
                alphabet = 0
                
            elif z == 0:
                s.append(ord(' '))
                
            elif 1 <= z <= 3:
                abbreviation = z
                
            elif z == 4:
                alphabet = 1
                
            elif z == 5:
                alphabet = 2  
                
        return s
        
    def zstrtozscii(self, zstr):
        '''Convert a z-encoded byte string to a byte string of zchars'''
        return bytes(self.zstrtozchars(zstr)) 
    
    def get_abbreviation_loc(self, n):
        return self.unpacked_addr(self.get_word_at(self.header.abbreviation_table_loc + n*2))
    
    def update_status(self, obj, score):
        self.print_str(''.join([obj, " ", score, "\n"]))

    def print_status(self):
        if self.header.version <= 3:
            time_game = self.header.flags1 & 0x02
            g00 = self.get_global(0)
            if g00 == 0:
                raise RuntimeError("print_status: attempt to print object 0")
            addr = self.get_object_short_description_addr(g00)
            obj = self.zcharlist_to_str(self.get_zchars_at(addr))
            if time_game:
                score = time.strftime("%H:%M")
            else:
                score = "{:d} / {:d}".format(word_to_signed(self.get_global(1)), 
                                                        word_to_signed(self.get_global(2)))
            self.status_thing.update_status(obj, score)
            
        
    def zcharlist_to_str(self, zcharlist):
        '''Convert a list of zchars to a Python string'''
        return ''.join(map(self.zchar_to_char, zcharlist))
    
    def _zscii_filter(self, c):
        '''Determine if a given Unicode character can be translated to a zscii character'''
        return c == '\n' or 32 <= ord(c) <= 126 or c in Memory.default_translation
    
    def split_with_separators(self, text_buffer, separators):
        '''Split an sread-style text_buffer at separators, stripping the result of whitespace.
           The result is a list of tuples. The first element is the index into the original string
           where the word starts, and the second element is the word (byte string).
           Separators (but not whitespace) are part of the returned list.'''
        maxchars = self.memory[text_buffer]
        buffer = itertools.takewhile(lambda x: x!=b'\x00', 
                                     memoryview(self.memory)[text_buffer:text_buffer+1+maxchars])
        splits = [None]
        for pos, b in enumerate(buffer):
            if pos == 0:
                continue
            if b in separators:
                if splits[-1] == None:
                    splits[-1] = pos, b
                else:
                    splits.append( (pos, b) )
                splits.append(None)
            elif b == b' ':
                if splits[-1] != None:
                    splits.append(None)
            else:
                if splits[-1] == None:
                    splits[-1] = pos, b
                else:
                    p, s = splits[-1]
                    splits[-1] = p, s+b
        if splits[-1] == None:
            splits = splits[:-1]
        return splits
        
    def parse_by_dictionary(self, text_buffer, parse_buffer, dictionary_loc):
        '''Parse the given sread-style text buffer using the given dictionary'''
        num_separators = self.memory[dictionary_loc]
        word_separators = self.memory[dictionary_loc + 1 : dictionary_loc + 1 + num_separators]
        entry_length = self.memory[dictionary_loc + 1 + num_separators]
        num_entries = self.get_word_at(dictionary_loc + 2 + num_separators)
        start_entry_loc = dictionary_loc + 4 + num_separators
        entry_word_size = 4 if self.header.version <= 3 else 6

        words = self.split_with_separators(text_buffer, word_separators)
        maxwords = self.memory[parse_buffer]
        parsed_words = []
        
        print("PARSE: ", end='')
        print(words)
        
        for _, word in words:
            bword = zsciitozstr(word, entry_word_size>>1)
            entry_loc = 0
            for loc in range(start_entry_loc, start_entry_loc + num_entries * entry_length, entry_length):
                if bword == self.memory[loc : loc + entry_word_size]:
                    entry_loc = loc
                    break
            parsed_words.append(entry_loc)
        
        parsed_words = parsed_words[:maxwords]
        self.memory[parse_buffer + 1] = len(parsed_words)
        print("PARSED WORDS: ", end='')
        for i, parsed_word in enumerate(parsed_words):
            self.set_word_at(parse_buffer + 2 + i*4, parsed_word)
            pos, word = words[i]
            self.memory[parse_buffer + 4 + i*4] = len(word)
            self.memory[parse_buffer + 5 + i*4] = pos
            print("({:04X}, {:02X}, {:02X}) ".format(parsed_word, len(word), pos), end='')
        print()
        
        
    def sread_parse_text_buffer(self, text_buffer, parse_buffer):
        '''Parse the given sread-style text buffer using the default dictionary'''
        self.parse_by_dictionary(text_buffer, parse_buffer, self.header.dictionary_loc)
        
    def sread_readline_to_text_buffer(self, text_buffer):
        '''Read a line from the input stream, convert to zscii, ignore any invalid characters,
           remove any final newline, and place in text_buffer up to size of text_buffer, terminating with 0.'''
        maxchars = self.memory[text_buffer]
        offset = 0 if self.header.version < 5 else self.memory[text_buffer + 1]
        start = 1 if self.header.version < 5 else offset + 2
        cmd = self.input_stream.readline()
        cmd = map(str.lower, filter(self._zscii_filter, cmd))
        zchars = list( map(self.char_to_zchar, itertools.islice(cmd, maxchars-offset) ) )
        zchars[-1] = 0
        self.memory[text_buffer + start : text_buffer + start + len(zchars)] = zchars
    
    def sread(self, text_buffer, parse_buffer):
        '''Read a line from input and parse it'''
        self.sread_readline_to_text_buffer(text_buffer)
        self.sread_parse_text_buffer(text_buffer, parse_buffer)
        
    def unpacked_addr(self, addr, is_call = False):
        if self.header.version <= 3:
            return 2*addr
        elif self.header.version <= 5:
            return 4*addr
        elif self.header.version <= 7:
            base = 4*addr
            if is_call:
                return base + self.header.routines_offest
            return base + self.header.static_strings_offset
        else:
            return 8*addr
        
    def get_word_at(self, loc):
        '''Gets the 2-byte word at loc'''
        return self.memory[loc]<<8 | self.memory[loc+1]
    
    def set_word_at(self, loc, value):
        '''Sets the 2-byte word at loc to value'''
        self.memory[loc] = value >> 8
        self.memory[loc + 1] = value & 0xFF
        
    def get_object_base(self, obj):
        '''Get the start location of the object's entry in the object table'''
        if obj == 0:
            raise RuntimeError("Attempt to access object 0")
        if self.header.version <= 3:
            return self.header.object_table_loc + 9*(obj-1) + 62
        return self.header.object_table_loc + 14*(obj-1) + 126
    
    def get_object_default_prop_base(self):
        '''Get the start location of the default properties in the object table'''
        return self.header.object_table_loc
    
    def get_object_prop_base(self, obj):
        '''Get the start location of the object's property table'''
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            return self.get_word_at(base + 7)
        return self.get_word_at(base + 12)        
        
    def get_object_parent(self, obj):
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            return self.memory[base + 4]
        return self.get_word_at(base + 6)
    
    def get_object_sibling(self, obj):
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            return self.memory[base + 5]
        return self.get_word_at(base + 8)
    
    def get_object_child(self, obj):
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            return self.memory[base + 6]
        return self.get_word_at(base + 10)
    
    def set_object_parent(self, obj, parent):
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            self.memory[base + 4] = parent
            return
        self.set_word_at(base + 6, parent)
    
    def set_object_sibling(self, obj, sibling):
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            self.memory[base + 5] = sibling
            return
        self.set_word_at(base + 8, sibling)
    
    def set_object_child(self, obj, child):
        base = self.get_object_base(obj)
        if self.header.version <= 3:
            self.memory[base + 6] = child
            return
        self.set_word_at(base + 10, child)
    
    def get_object_attr_loc_mask(self, obj, attr):
        '''Returns a tuple of the location of the object's attribute and the
            bit mask to access it'''
        offset = attr >> 3
        bitmask = 1 << (7 - (attr&0x07))
        base = self.get_object_base(obj)
        return base+offset, bitmask
    
    def get_object_attr(self, obj, attr):
        loc, bitmask = self.get_object_attr_loc_mask(obj, attr)
        return self.memory[loc] & bitmask

    def set_object_attr(self, obj, attr):
        loc, bitmask = self.get_object_attr_loc_mask(obj, attr)
        self.memory[loc] |= bitmask

    def clr_object_attr(self, obj, attr):
        loc, bitmask = self.get_object_attr_loc_mask(obj, attr)
        self.memory[loc] &= (~bitmask & 0xFF)
        
    def insert_object(self, obj, dest):
        if obj == 0:
            raise RuntimeError("insert_object: Attempt to insert object 0")
        
        # First remove the object from its sibling chain, if any
        
        parent = self.get_object_parent(obj)
        if parent != 0:
            prevchild = 0
            child = self.get_object_child(parent)
            while child != 0 and child != obj:
                prevchild = child
                child = self.get_object_sibling(child)
            if child == 0:
                raise RuntimeError("Inconsistent object child chain for object {:02X}"
                                   " (object {:02X} not found)".format(parent, obj))
            if prevchild == 0:
                self.set_object_child(parent, self.get_object_sibling(obj))
            else:
                self.set_object_sibling(prevchild, self.get_object_sibling(obj))
                
        # Now put the object into the new sibling chain
        
        if dest != 0:
            dest_child = self.get_object_child(dest)
            self.set_object_child(dest, obj)
            self.set_object_sibling(obj, dest_child)
        self.set_object_parent(obj, dest)
        
    def get_object_short_description_addr(self, obj):
        start_loc = self.get_object_prop_base(obj)
        return start_loc + 1
    
    def get_object_prop_size(self, prop_loc):
        '''Decode the property size'''
        if self.header.version <= 3:
            return (self.memory[prop_loc]>>5) + 1
        if not self.memory[prop_loc]&0x80:
            return 2 if self.memory[prop_loc]&0x40 else 1
        return self.memory[prop_loc+1]&0x3F
        
    def get_object_prop_size_from_addr(self, prop_addr):
        '''Given the start of the property data, step backwards to find the property size'''
        if self.header.version <= 3:
            return (self.memory[prop_addr-1]>>5) + 1
        if not self.memory[prop_addr-1]&0x80:
            return 2 if self.memory[prop_addr-1]&0x40 else 1
        return self.memory[prop_addr-1]&0x3F
        
        
    def get_object_prop_addr_and_size(self, obj, prop):
        '''Return the location and size of the object's property data 
         for the given property, or 0,0 if no such property.'''
        start_loc = self.get_object_prop_base(obj)
        if prop == 0:
            raise RuntimeError("get_object_prop_addr_and_size: attempt to access object {:02X} property 0".format(obj))
        text_length = self.memory[start_loc]
        prop_loc = start_loc + 1 + 2*text_length
        
        # Format for v3 and below property table is simple.
        # The first byte is zero for end of table, otherwise the bottom 5 bits is the
        # property number, and the top 3 bits, plus one, is the size (so size is 1 through 8).
        # Then follows the given (size) number of bytes for the property data.
        # Also, the properties are required to be in descending property number order,
        # which means that if you see a property number that is lower than the one you're looking
        # for, you can just exit.
        
        if self.header.version <= 3:
            while self.memory[prop_loc]:
                size = self.get_object_prop_size(prop_loc)
                propnum = self.memory[prop_loc]&0x1F
                prop_loc += 1
                if propnum < prop:
                    return 0, 0
                if propnum == prop:
                    return prop_loc, size
                prop_loc += size
            # fell off the end of the property table, oh noes
            return 0, 0
        
        # More complicated v4+ property table.
        # The property number is the bottom 6 bits of the first byte.
        # If bit 7 is clear, then bit 6 says how many bytes are in the data: a clear bit means 1,
        # and a set bit means 2.
        # But if bit 7 is set, then the next byte's bottom 6 bits is the size of the data that follows.
        
        while self.memory[prop_loc]:
            size = self.get_object_prop_size(prop_loc)
            propnum = self.memory[prop_loc]&0x3F
            if not self.memory[prop_loc]&0x80: # how many bytes in size
                prop_loc += 1
            else:
                prop_loc += 2
            if propnum < prop:
                return 0, 0
            if propnum == prop:
                return prop_loc, size
            prop_loc += size
        # fell off the end of the property table, oh noes
        return 0, 0        
        
    def get_object_prop_data(self, obj, prop):
        '''Return a byte array containing the object's property data for the given property.'''
        addr, size = self.get_object_prop_addr_and_size(obj, prop)
        if not addr:
            defaults = self.get_object_default_prop_base()
            addr = defaults + ((prop-1)<<1)
            return self.memory[addr : addr + 2]
        return self.memory[addr : addr + size]
    
    def set_object_prop_data(self, obj, prop, data):
        addr, size = self.get_object_prop_addr_and_size(obj, prop)
        if not addr:
            raise RuntimeError("set_object_prop_data: Property {:02X} in object {:04X} does not exist.".format(prop, obj))
        if size == 1:
            self.memory[addr] = data & 0xFF
        elif size == 2:
            self.set_word_at(addr, data)
    
    def get_object_next_prop(self, obj, prop):
        '''Return the number of the property after the given property, 
           or None if the given property does not exist. 
           The property returned may be zero if there is no next property. If
           the property given is 0, return the first property number.'''
        if prop == 0:
            start_loc = self.get_object_prop_base(obj)
            text_length = self.memory[start_loc]
            prop_loc = start_loc + 1 + 2*text_length
        else:   
            addr, size = self.get_object_prop_addr_and_size(obj, prop)
            if addr == 0: # no such property
                raise RuntimeError("get_object_next_prop: Property {:02X} in object {:04X} does not exist.".format(prop, obj))
            prop_loc = addr + size
        
        if self.header.version <= 3:
            return self.memory[prop_loc]&0x1F
        return self.memory[prop_loc]&0x3F
       


    def get_global(self, gnum):
        base = self.header.global_var_table_loc + 2*gnum
        return self.get_word_at(base)
    
    def set_global(self, gnum, value):
        base = self.header.global_var_table_loc + 2*gnum
        self.set_word_at(base, value)
            
            
class Header:
    '''Decoded view of the fixed 0x38-byte Z-machine story-file header.

    All fields are read once at construction. The memory object is kept as
    self.memory.'''

    # Big-endian layout of bytes 0x00-0x37. The skipped ('x') bytes are:
    #   0x02-0x03 release number, 0x10-0x17 flags 2 and serial number,
    #   0x2E-0x2F terminating characters table address.
    _LAYOUT = struct.Struct(">BBxxHHHHHHxxxxxxxxHHHBBBBHHBBHHBBxxHHHH")

    def __init__(self, memory):
        self.memory = memory

        # Read the non-modifiable data.

        (self.version,
         self.flags1,
         self.highmem_base,
         self.initial_PC,
         self.dictionary_loc,
         self.object_table_loc,
         self.global_var_table_loc,
         self.static_mem_base,
         self.abbreviation_table_loc,
         self.file_length,
         self.checksum,
         self.interpreter_number,
         self.interpreter_version,
         self.screen_height_lines,
         self.screen_width_chars,
         self.screen_width_units,     # 0x22: screen width comes first
         self.screen_height_units,    # 0x24
         self.font_width_units,
         self.font_height_units,
         self.routines_offset,
         self.static_strings_offset,
         self.default_background_color,
         self.default_foreground_color,
         self.pixel_width_stream_3,
         self.revision_number,
         self.alphabet_table_addr,
         self.header_ext_table_addr) = Header._LAYOUT.unpack_from(memory, 0)

        # FIXME: Address of terminating characters table
        # FIXME: flags2
        # FIXME: header extension

        # Both offsets are stored divided by 8.
        self.routines_offset *= 8
        self.static_strings_offset *= 8

        # The stored file length is divided by a version-dependent constant.
        if self.version <= 3:
            self.file_length *= 2
        elif self.version <= 5:
            self.file_length *= 4
        else:
            self.file_length *= 8

    def __str__(self):
        return ("Version            : {:02d}\n"
                "Base of high memory: {:04X}\n"
                "Initial PC         : {:04X}\n"
                "Dictionary loc     : {:04X}\n"
                "Object table loc   : {:04X}\n"
                "Global var table   : {:04X}\n"
                "Static mem base    : {:04X}\n"
                "Abbrev table loc   : {:04X}\n"
                "File length        : {:08d}\n").format(self.version,
                                                         self.highmem_base,
                                                         self.initial_PC,
                                                         self.dictionary_loc,
                                                         self.object_table_loc,
                                                         self.global_var_table_loc,
                                                         self.static_mem_base,
                                                         self.abbreviation_table_loc,
                                                         self.file_length)

# See http://www.gnelson.demon.co.uk/zspec/sect14.html for instruction descriptions.

class Instruction_Interpreter:

    # opcode tables are:
    # 0: 2OP
    # 1: 0OP
    # 2: 1OP
    # 3: VAR
    # 4: EXT
    #
    # opnames[table][op] is the mnemonic used in instruction traces. Only tables
    # 0-3 exist here; there is no EXT name table yet. Some slots are reused by
    # later versions but only one name is stored: 0OP 0x09 (POP; CATCH in v5+),
    # 1OP 0x0F (NOT; CALL_1N in v5+) and VAR 0x04 (SREAD; AREAD in v5+).
    
    opnames = [ [ "", "JE", "JL", "JG", "DEC_CHK", "INC_CHK", "JIN", "TEST", 
               "OR", "AND", "TEST_ATTR", "SET_ATTR", "CLEAR_ATTR", "STORE", "INSERT_OBJ", "LOADW",
               "LOADB", "GET_PROP", "GET_PROP_ADDR", "GET_NEXT_PROP", "ADD", "SUB", "MUL", "DIV",
               "MOD", "CALL_2S", "CALL_2N", "SET_COLOUR", "THROW", "", "", ""],
               [ "RTRUE", "RFALSE", "PRINT", "PRINT_RET", "NOP", "SAVE", "RESTORE", "RESTART",
                 "RET_POPPED", "POP", "QUIT", "NEW_LINE", "SHOW_STATUS", "VERIFY", "", "PIRACY"],
               [ "JZ", "GET_SIBLING", "GET_CHILD", "GET_PARENT", "GET_PROP_LEN", "INC", "DEC", "PRINT_ADDR",
                "CALL_1S", "REMOVE_OBJ", "PRINT_OBJ", "RET", "JUMP", "PRINT_PADDR", "LOAD", "NOT"],
               [ "CALL", "STOREW", "STOREB", "PUT_PROP", "SREAD", "PRINT_CHAR", "PRINT_NUM", "RANDOM",
                "PUSH", "PULL", "SPLIT_WINDOW", "SET_WINDOW", "CALL_VS2", "ERASE_WINDOW", "ERASE_LINE",
                "SET_CURSOR", "GET_CURSOR", "SET_TEXT_STYLE", "BUFFER_MODE", "OUTPUT_STREAM",
                "INPUT_STREAM", "SOUND_EFFECT", "READ_CHAR", "SCAN_TABLE", "NOT",
                "CALL_VN", "CALL_VN2", "TOKENIZE", "ENCODE_TEXT", "COPY_TABLE", "PRINT_TABLE",
                "CHECK_ARG_COUNT"
                ] ]

    # neg_opnames[table][op]: alternative mnemonics for branch instructions,
    # presumably for traces of branches taken on a false condition. They are not
    # referenced in this part of the file -- confirm against their users.
    neg_opnames = [ { 0x01:"JNE", 0x02:"JGE", 0x03:"JLE", 0x04:"DEC_CHKGE", 0x05:"INC_CHKLE", 
                     0x06:"JNIN" },
                   { },
                   { 0x00:"JNZ" },
                   { } ]
    
    # Per-instruction property tables, keyed by opcode table (0 2OP, 1 0OP,
    # 2 1OP, 3 VAR, 4 EXT) and op number. Two key forms are used:
    #   (opcode_table, op): v      -> property holds from version v onwards
    #   (opcode_table, op, v): 0   -> property holds in version v only
    # interpret_instr checks both forms against the story file's version.
    
    # Instructions followed by branch data.
    branch_instr_dict = { (0, 0x01): 1, (0, 0x02): 1, (0, 0x03): 1, (0, 0x04): 1, (0, 0x05): 1, 
                          (0, 0x06): 1, (0, 0x07): 1, (0, 0x0A): 1, 
                          
                          # SAVE/RESTORE branch in v1-3.
                          (1, 0x05, 1): 0, (1, 0x05, 2): 0, (1, 0x05, 3): 0, 
                          (1, 0x06, 1): 0, (1, 0x06, 2): 0, (1, 0x06, 3): 0, 
                          (1, 0x0D): 3, (1, 0x0F): 5,
                          
                          # JZ, GET_SIBLING and GET_CHILD branch in every version.
                          (2, 0x00): 1, (2, 0x01): 1, (2, 0x02): 1,
                          
                          (3, 0x17): 4, (3, 0x1F): 5,
                          
                          (4, 0x06): 6, (4, 0x18): 6, (4, 0x1B): 6 }
    
    # Instructions followed by a store-variable byte.
    store_instr_dict = { (0, 0x08): 1, (0, 0x09): 1, (0, 0x0F): 1, 
                         (0, 0x10): 1, (0, 0x11): 1, (0, 0x12): 1, (0, 0x13): 1, (0, 0x14): 1, 
                         (0, 0x15): 1, (0, 0x16): 1, (0, 0x17): 1, (0, 0x18): 1, 
                         (0, 0x19): 4,
                         
                         # SAVE/RESTORE store a result in v4 (they branch in v1-3).
                         (1, 0x05, 4): 0, (1, 0x06, 4): 0,
                         (1, 0x09): 5,
                         
                         # 0x08 is CALL_1S (v4+), which stores the routine's result.
                         (2, 0x01): 1, (2, 0x02): 1, (2, 0x03): 1, (2, 0x04): 1, (2, 0x08): 4, (2, 0x0E): 1,
                         (2, 0x0F, 1): 0, (2, 0x0F, 2): 0, (2, 0x0F, 3): 0, (2, 0x0F, 4): 0, 
                         
                         (3, 0x00): 1, (3, 0x04): 5, (3, 0x07): 1, (3, 0x09): 6,
                         (3, 0x0C): 4, (3, 0x16): 4, (3, 0x17): 4, (3, 0x18): 5,
                         
                         # EXT: save, restore, log_shift, art_shift, set_font, save_undo,
                         # restore_undo, check_unicode (0x0C); in v6 also get_wind_prop
                         # (0x13, i.e. decimal 19) and buffer_screen (0x1D).
                         (4, 0x00): 5, (4, 0x01): 5, (4, 0x02): 5, (4, 0x03): 5, (4, 0x04): 5,
                         (4, 0x09): 5, (4, 0x0A): 5, (4, 0x0C): 5,
                         (4, 0x13): 6, (4, 0x1D): 6 }
    
    # Instructions whose first operand is a (packed) routine address.
    call_instr_dict = { (0, 0x19): 4, (0, 0x1A): 5,
                       
                        (2, 0x08): 4, (2, 0x0F): 5,
                        
                        (3, 0x00): 1, (3, 0x0C): 4, (3, 0x19): 5, (3, 0x1A): 5 }
    
    opcode_type_dict = { 0: "long, 2-operand", 1: "short, 0-operand", 2: "short, 1-operand",
                         3: "variable", 4: "extended" }
    
    operand_type_dict = { 0: "word", 1: "byte", 2: "var", 3: "not present" }

    def __init__(self, memory):
        self.memory = memory
        self.jumptable = [ ### 2OP
                          [self._illegal,
                              self._je, 
                              self._jl, 
                              self._jg, 
                              self._dec_chk,
                              self._inc_chk, 
                              self._jin, 
                              self._test, 
                              self._or,
                              self._and,
                              self._test_attr,
                              self._set_attr,
                              self._clear_attr,
                              self._store,
                              self._insert_obj,
                              self._loadw,
                              self._loadb,
                              self._get_prop,
                              self._get_prop_addr,
                              self._get_next_prop,
                              self._add,
                              self._sub,
                              self._mul,
                              self._div,
                              self._mod,
                              self._call_2s if self.memory.header.version >= 4 else self._illegal,
                              self._call_2n if self.memory.header.version >= 5 else self._illegal,
                              self._set_colour if self.memory.header.version >= 5 else self._illegal,
                              self._throw if self.memory.header.version >= 5 else self._illegal,
                              self._illegal,
                              self._illegal,
                              self._illegal
                              ],
                          ### 0OP
                          [self._rtrue,
                           self._rfalse,
                           self._print,
                           self._print_ret,
                           self._nop,
                           self._save if self.memory.header.version <= 4 else self._illegal,
                           self._restore if self.memory.header.version <= 4 else self._illegal,
                           self._restart,
                           self._ret_popped,
                           self._pop if self.memory.header.version <= 4 else self._catch,
                           self._quit,
                           self._new_line,
                           self._show_status if self.memory.header.version == 3 else self._illegal,
                           self._verify if self.memory.header.version >= 3 else self._illegal,
                           self._illegal, # actually extended
                           self._piracy if self.memory.header.version >= 5 else self._illegal
                           ],
                          ### 1OP
                          [self._jz,
                           self._get_sibling,
                           self._get_child,
                           self._get_parent,
                           self._get_prop_len,
                           self._inc,
                           self._dec,
                           self._print_addr,
                           self._call_1s if self.memory.header.version >= 4 else self._illegal,
                           self._remove_obj,
                           self._print_obj,
                           self._ret,
                           self._jump,
                           self._print_paddr,
                           self._load,
                           self._not if self.memory.header.version < 5 else self._call_1n,
                           self._rtrue,
                           self._rfalse,
                           self._print,
                           self._print_ret,
                           self._nop,
                           self._save if self.memory.header.version < 5 else self._illegal,
                           self._restore if self.memory.header.version < 5 else self._illegal,
                           self._restart,
                           self._ret_popped,
                           self._pop if self.memory.header.version < 5 else self._catch,
                           self._quit,
                           self._new_line,
                           self._show_status if self.memory.header.version == 3 else self._nop,
                           self._verify if self.memory.header.version >= 3 else self._illegal,
                           self._illegal, # actually slot for extended
                           self._piracy if self.memory.header.version >= 5 else self._illegal
                           ],
                          ### VAR
                          [self._call if self.memory.header.version < 4 else self._call_vs,
                           self._storew,
                           self._storeb,
                           self._put_prop,
                           self._sread if self.memory.header.version <= 4 else self._aread,
                           self._print_char,
                           self._print_num,
                           self._random,
                           self._push,
                           self._pull,
                           self._split_window if self.memory.header.version >= 3 else self._illegal,
                           self._set_window if self.memory.header.version >= 3 else self._illegal,
                           self._call_vs2 if self.memory.header.version >= 4 else self._illegal,
                           self._erase_window if self.memory.header.version >= 4 else self._illegal,
                           self._erase_line if self.memory.header.version >= 4 else self._illegal,
                           self._set_cursor if self.memory.header.version >= 4 else self._illegal,
                           self._get_cursor if self.memory.header.version >= 4 else self._illegal,
                           self._set_text_style if self.memory.header.version >= 4 else self._illegal,
                           self._buffer_mode if self.memory.header.version >= 4 else self._illegal,
                           self._output_stream if self.memory.header.version >= 3 else self._illegal,
                           self._input_stream if self.memory.header.version >= 3 else self._illegal,
                           self._sound_effect if self.memory.header.version >= 3 else self._illegal,
                           self._read_char if self.memory.header.version >= 4 else self._illegal,
                           self._scan_table if self.memory.header.version >= 4 else self._illegal,
                           self._not if self.memory.header.version >= 5 else self._illegal,
                           self._call_vn if self.memory.header.version >= 5 else self._illegal,
                           self._call_vn2 if self.memory.header.version >= 5 else self._illegal,
                           self._tokenize if self.memory.header.version >= 5 else self._illegal,
                           self._encode_text if self.memory.header.version >= 5 else self._illegal,
                           self._copy_table if self.memory.header.version >= 5 else self._illegal,
                           self._print_table if self.memory.header.version >= 5 else self._illegal,
                           self._check_arg_count if self.memory.header.version >= 5 else self._illegal,
                           ] ]
    
    def _illegal(self):
        '''Handler for opcode slots that are invalid in this story file's version.

        Raises IllegalInstructionError with the instruction address and opcode byte.
        '''
        location, opcode_byte = self.instr_loc, self.opcode
        raise IllegalInstructionError(location, opcode_byte)
        
    def _je(self):
        self.condition = False
        for o in self.operands[1:]:
            self.condition = self.condition or self.operands[0] == o
        
    def _jl(self):
        '''JL: condition is true when operands[0] < operands[1] as signed 16-bit words.'''
        lhs = word_to_signed(self.operands[0])
        rhs = word_to_signed(self.operands[1])
        self.condition = lhs < rhs
        
    def _jg(self):
        '''JG: condition is true when operands[0] > operands[1] as signed 16-bit words.'''
        lhs = word_to_signed(self.operands[0])
        rhs = word_to_signed(self.operands[1])
        self.condition = lhs > rhs
     
    def _dec_chk(self):
        '''DEC_CHK: decrement variable self.store_var (16-bit wrap) into self.store_value.

        The condition is true when the new value, read as signed, is less than
        operands[1] (also signed). The write-back of store_value is left to the
        caller (interpret_instr marks DEC_CHK as a store instruction).
        '''
        decremented = self.get_variable(self.store_var) - 1
        self.store_value = decremented & 0xFFFF
        limit = word_to_signed(self.operands[1])
        self.condition = word_to_signed(self.store_value) < limit
        
    def _inc_chk(self):
        '''INC_CHK: increment variable self.store_var (16-bit wrap) into self.store_value.

        The condition is true when the new value, read as signed, is greater than
        operands[1] (also signed). The write-back of store_value is left to the
        caller (interpret_instr marks INC_CHK as a store instruction).
        '''
        incremented = self.get_variable(self.store_var) + 1
        self.store_value = incremented & 0xFFFF
        limit = word_to_signed(self.operands[1])
        self.condition = word_to_signed(self.store_value) > limit
        
    def _jin(self):
        parent = self.memory.get_object_parent(self.operands[0])
        self.condition = parent == self.operands[1]
        
    def _test(self):
        self.condition = self.operands[0] & self.operands[1] == self.operands[1]
    
    def _or(self):
        self.store_value = self.operands[0] | self.operands[1]

    def _and(self):
        self.store_value = self.operands[0] & self.operands[1]

    def _test_attr(self):
        attr = self.memory.get_object_attr(self.operands[0], self.operands[1])
        self.condition = attr
    
    def _set_attr(self):
        self.memory.set_object_attr(self.operands[0], self.operands[1])

    def _clear_attr(self):
        self.memory.clr_object_attr(self.operands[0], self.operands[1])

    def _store(self):
        self.store_value = self.operands[1]

    def _insert_obj(self):
        self.memory.insert_object(self.operands[0], self.operands[1])
    
    def _loadw(self):
        self.store_value = self.memory.get_word_at(self.operands[0] + 2*self.operands[1])
        
    def _loadb(self):
        self.store_value = self.memory.memory[self.operands[0] + self.operands[1]]
       
    def _get_prop(self):
        prop = self.memory.get_object_prop_data(self.operands[0], self.operands[1])
        if len(prop) == 1:
            self.store_value = prop[0]
        elif len(prop) == 2:
            self.store_value = prop[0]<<8 | prop[1]
        else:
            raise RuntimeError("GET_PROP called for object {:02X} "
                               "property {:02X} but property length "
                               "{:02X} is greater than 2.".format(self.operands[0], self.operands[1], len(prop)))
            
    def _get_prop_addr(self):
        addr, _ = self.memory.get_object_prop_addr_and_size(self.operands[0], self.operands[1])
        self.store_value = addr

    def _get_next_prop(self):
        prop = self.memory.get_object_next_prop(self.operands[0], self.operands[1])
        # FIXME: die if prop == None
        self.store_value = prop

    def _add(self):
        self.store_value = (self.operands[0] + self.operands[1]) & 0xFFFF
        
    def _sub(self):
        self.store_value = (self.operands[0] - self.operands[1]) & 0xFFFF
        
    def _mul(self):
        self.store_value = (self.operands[0] * self.operands[1]) & 0xFFFF
        
    def _div(self):
        # FIXME: die if operands[1] == 0
        self.store_value = self.operands[0] // self.operands[1]
        
    def _mod(self):
        # FIXME: die if operands[1] == 0
        self.store_value = self.operands[0] % self.operands[1]
    
    def _call_2s(self):
        pass
    
    def _call_2n(self):
        pass
    
    def _set_colour(self):
        raise NotImplementedError("SET_COLOUR is not yet implemented")
    
    def _throw(self):
        raise NotImplementedError("THROW is not yet implemented")
    
    def _jz(self):
        self.condition = self.operands[0] == 0
    
    def _get_sibling(self):
        sibling = self.memory.get_object_sibling(self.operands[0])
        self.store_value = sibling
        self.condition = sibling != 0
    
    def _get_child(self):
        child = self.memory.get_object_child(self.operands[0])
        self.store_value = child
        self.condition = child != 0
    
    def _get_parent(self):
        self.store_value = self.memory.get_object_parent(self.operands[0])
        
    def _get_prop_len(self):
        if self.operands[0] <= 0x40:
            raise RuntimeError("GET_PROP_LEN: Attempt to access non-object address {:04X}".format(self.operands[0]))
        self.store_value = self.memory.get_object_prop_size_from_addr(self.operands[0])
        
    def _inc(self):
        self.store_value = (self.get_variable(self.operands[0]) + 1) & 0xFFFF
        
    def _dec(self):
        self.store_value = (self.get_variable(self.operands[0]) - 1) & 0xFFFF
        
    def _print_addr(self):
        self.memory.print_zstr_at(self.operands[0])
        
    def _call_1s(self):
        pass
        
    def _remove_obj(self):
        self.memory.insert_object(self.operands[0], 0)
        
    def _print_obj(self):
        self.memory.print_zstr_at(self.memory.get_object_short_description_addr(self.operands[0]))
        
    def _ret(self):
        self.return_value = self.operands[0]
        
    def _jump(self):
        self.condition = True
        
    def _print_paddr(self):
        self.memory.print_zstr_at(self.memory.unpacked_addr(self.operands[0]))
        
    def _load(self):
        self.store_value = self.get_variable(self.operands[0])
        
    def _not(self):
        self.store_value = (~self.operands[0]) & 0xFFFF
        
    def _call_1n(self):
        pass
        
    def _rtrue(self):
        self.return_value = 1
        
    def _rfalse(self):
        self.return_value = 0
        
    def _print(self):
        self.memory.print_zstr_at(self.operands[0])
        
    def _print_ret(self):
        self.memory.print_zstr_at(self.operands[0])
        self.memory.print_nl()
        self.return_value = 1
        
    def _nop(self):
        pass
    
    def _save(self):
        raise NotImplementedError("SAVE is not yet implemented")
    
    def _restore(self):
        raise NotImplementedError("RESTORE is not yet implemented")
    
    def _restart(self):
        raise NotImplementedError("RESTART is not yet implemented")
    
    def _ret_popped(self):
        self.return_value = self.stack.pop()
        
    def _pop(self):
        self.stack.pop()
        
    def _quit(self):
        exit()
        
    def _new_line(self):
        self.memory.print_nl()
        
    def _show_status(self):
        self.memory.print_status()
        
    def _verify(self):
        self.condition = sum(self.memory.memory[0x40:]) & 0xFFFF == self.memory.header.checksum
        
    def _piracy(self):
        self.condition = True
        
    def _call(self):
        pass
    
    def _call_vs(self):
        pass
    
    def _storew(self):
        self.memory.set_word_at(self.operands[0] + 2*self.operands[1], self.operands[2] & 0xFFFF)
        
    def _storeb(self):
        self.memory.memory[self.operands[0] + self.operands[1]] = self.operands[2] & 0xFF
        
    def _put_prop(self):
        self.memory.set_object_prop_data(self.operands[0], self.operands[1], self.operands[2])
        
    def _sread(self):
        if self.memory.header.version <= 3:
            self.memory.print_status()
        self.memory.sread(self.operands[0], self.operands[1])
        
    def _aread(self):
        self.memory.sread()
        
    def _print_char(self):
        self.memory.print_zscii(self.operands[0])
        
    def _print_num(self):
        '''PRINT_NUM: print operands[0] as a signed 16-bit decimal number.'''
        self.memory.print_num(word_to_signed(self.operands[0]))
        
    def _random(self):
        if self.operands[0] >= 0x8000:
            random.seed(0x7FFF - self.operands[0])
            self.store_value = 0
        elif self.operands[0] == 0:
            random.seed()
            self.store_value = 0
        else:
            self.store_value = random.randint(1, self.operands[0])
        
    def _push(self):
        self.stack.append(self.operands[0])
        
    def _pull(self):
        self.store_value = self.stack.pop()
        
    def _split_window(self):
        raise NotImplementedError("SPLIT_WINDOW is not yet implemented")
    
    def _set_window(self):
        raise NotImplementedError("SET_WINDOW is not yet implemented")
    
    def _call_vs2(self):
        pass
    
    def _erase_window(self):
        raise NotImplementedError("ERASE_WINDOW is not yet implemented")
    
    def _erase_line(self):
        raise NotImplementedError("ERASE_LINE is not yet implemented")
    
    def _set_cursor(self):
        raise NotImplementedError("SET_CURSOR is not yet implemented")
    
    def _get_cursor(self):
        raise NotImplementedError("GET_CURSOR is not yet implemented")

    def _set_text_style(self):
        raise NotImplementedError("SET_TEXT_STYLE is not yet implemented")

    def _buffer_mode(self):
        raise NotImplementedError("BUFFER_MODE is not yet implemented")
        
    def _output_stream(self):
        raise NotImplementedError("OUTPUT_STREAM is not yet implemented")

    def _input_stream(self):
        raise NotImplementedError("INPUT_STREAM is not yet implemented")
        
    def _sound_effect(self):
        raise NotImplementedError("SOUND_EFFECT is not yet implemented")

    def _read_char(self):
        raise NotImplementedError("READ_CHAR is not yet implemented")
    
    def _scan_table(self):
        raise NotImplementedError("SCAN_TABLE is not yet implemented")
    
    def _call_vn(self):
        pass
    
    def _call_vn2(self):
        pass
    
    def _tokenize(self):
        raise NotImplementedError("TOKENIZE is not yet implemented")
        
    def _encode_text(self):
        raise NotImplementedError("ENCODE_TEXT is not yet implemented")
    
    def _copy_table(self):
        raise NotImplementedError("COPY_TABLE is not yet implemented")
        
    def _print_table(self):
        raise NotImplementedError("PRINT_TABLE is not yet implemented")
    
    def _check_arg_count(self):
        raise NotImplementedError("CHECK_ARG_COUNT is not yet implemented")
    
    #
    # End of instructions
    #
    
    def verify(self):
        self._verify()
    
    def get_variable(self, var, peek = False):
        '''Get a 2-byte value for a variable'''
        if var == 0: # pop the stack
            if peek:
                return self.stack[-1]
            return self.stack.pop()
        
        if var <= 0x0F: # local variable var-1
            return self.local_vars[var-1]
        
        # otherwise global var-0x10
        return self.memory.get_global(var-0x10)
    
    def _trace_variable(self, var):
        if var == 0: # pop the stack
            return "(STACK)"
        
        if var <= 0x0F: # local variable var-1
            return "L{:02X}".format(var-1)
        
        # otherwise global var-0x10
        return "G{:02X}".format(var-0x10)
        
    def set_variable(self, var, value):
        '''Set a 2-byte value for a variable'''
        if var == 0: # push the stack
            self.stack.append(value)
        
        elif var <= 0x0F: # local variable var-1
            self.local_vars[var-1] = value
        
        else: # otherwise global var-0x10
            self.memory.set_global(var-0x10, value)
    
    def _get_routine_locals(self, routine_addr, return_addr, args):
        num_locals = self.memory.memory[routine_addr]
        routine_addr += 1
        if self.memory.header.version <= 4:
            hi = self.memory.memory[routine_addr : routine_addr + 2*num_locals : 2]
            lo = self.memory.memory[routine_addr+1 : routine_addr + 2*num_locals + 1 : 2]
            local_list = [h<<8 | l for h, l in zip(hi, lo)]
            routine_addr += 2*num_locals
        else:
            local_list = [0]*num_locals
        argvals, argtypes = args
        local_list[:len(argvals)] = argvals
        #for i, t in enumerate(argtypes):
        #    if t == 2:
        #        local_list[i] = self.get_variable(local_list[i])
        return local_list, routine_addr
        
    def _decode_branch(self, loc):
        self.branch_on_true = self.memory.memory[loc] & 0x80
        if self.memory.memory[loc] & 0x40: # short offset, 0 - 63
            branch_offset = self.memory.memory[loc] & 0x3F
            loc += 1
        else: # signed offset, 14 bits (3FFF)
            branch_offset = (self.memory.memory[loc] & 0x3F) << 8 | self.memory.memory[loc+1]
            if branch_offset >= 0x2000:
                branch_offset = branch_offset - 0x4000
            loc += 2
        if branch_offset == 0:
            self.return_value = 0
        elif branch_offset == 1:
            self.return_value = 1
        else:
            self.target = loc + branch_offset - 2 
        return loc
    
    def _decode_opcode(self, opcode, loc):
        '''Classify an opcode byte into (opcode table, op number).

        opcode: the instruction's first byte; loc: its address.
        Returns (loc, op, opcode_table). opcode_table indexes
        Instruction_Interpreter.opnames / self.jumptable (0: 2OP, 1: 0OP,
        2: 1OP, 3: VAR). loc is only advanced by the 0xBE branch.
        '''
        opcode_type = opcode >> 6  # top two bits: 0/1 long, 2 short, 3 variable
         
        if opcode == 0xBE: # ext
            # NOTE(review): op/opcode_table set here are dead. 0xBE >> 6 == 2, so
            # the "short" branch below always runs next and overwrites them (0xBE
            # decodes as 0OP op 0x0E); only the loc increment survives. Table 5
            # also disagrees with the "4: EXT" numbering used by the class tables.
            # A real fix needs an EXT entry in opnames/jumptable too -- confirm
            # before changing.
            loc += 1
            op = self.memory.memory[loc]
            opcode_table = 5 # EXT

        if opcode_type == 0 or opcode_type == 1: # long
            opcode_table = 0 # 2OP
            op = opcode & 0x1F
            
        elif opcode_type == 2: # short
            # Operand-type bits 5-4 == 11 means no operand, i.e. 0OP.
            if opcode & 0x30 == 0x30:
                opcode_table = 1 # 0OP
            else:
                opcode_table = 2 # 1OP
            op = opcode & 0x0F
            
        else: # var
            # Bit 5 set selects the VAR table; clear means a 2OP opcode in variable form.
            if opcode & 0x20:
                opcode_table = 3 # VAR
            else:
                opcode_table = 0 # 2OP
            op = opcode & 0x1F
            
        return loc, op, opcode_table
    
    def _decode_operands(self, opcode, op, loc):
        opcode_type = opcode >> 6
        
        if opcode_type == 0 or opcode_type == 1: # long
            self.operand_types = [2 if opcode & 0x40 else 1,
                                  2 if opcode & 0x20 else 1]
        elif opcode_type == 2: # short
            self.operand_types = [(opcode & 0x30) >> 4]

        else: # var
            self.operand_types = []
            loc += 1
            operand_type_byte = self.memory.memory[loc]
            while operand_type_byte & 0xFF != 0xFF:
                self.operand_types.append( (operand_type_byte & 0xC0) >> 6)
                operand_type_byte = (operand_type_byte << 2) | 0x03
            if operand_type_byte & 0x300 != 0x300 and (op == 0x0C or op == 0x1A): # can haz more operands
                loc += 1
                operand_type_byte = self.memory.memory[loc]
                while operand_type_byte & 0xFF != 0xFF:
                    self.operand_types.append( (operand_type_byte & 0xC0) >> 6)
                    operand_type_byte = (operand_type_byte << 2) | 0x03
                    
        return loc

    # See http://www.gnelson.demon.co.uk/zspec/sect04.html for opcode decoding

    def interpret_instr(self, loc, execute = True, traceop = False, prev_trace = None):
        if hasattr(self, "return_value"):
            del self.return_value
        self.instr_loc = loc
        self.opcode = self.memory.memory[loc]
        
        #
        # Determine which opcode table to use (0OP, 1OP, 2OP, VAR) and what the op is
        #
        
        loc, op, opcode_table = self._decode_opcode(self.opcode, loc)
        
        #
        # Determine number of operands, operand types, flags (branch, store, call)
        #
        
        loc = self._decode_operands(self.opcode, op, loc)
        
        branch_instr = (Instruction_Interpreter.branch_instr_dict.get( (opcode_table, op), 0xFFFFFFFF) <= self.memory.header.version or
            (opcode_table, op, self.memory.header.version) in Instruction_Interpreter.branch_instr_dict)
        store_instr = (Instruction_Interpreter.store_instr_dict.get( (opcode_table, op), 0xFFFFFFFF) <= self.memory.header.version or
            (opcode_table, op, self.memory.header.version) in Instruction_Interpreter.store_instr_dict)
        call_instr = (Instruction_Interpreter.call_instr_dict.get( (opcode_table, op), 0xFFFFFFFF) <= self.memory.header.version or
            (opcode_table, op, self.memory.header.version) in Instruction_Interpreter.call_instr_dict)
        
        op_name = Instruction_Interpreter.opnames[opcode_table][op]
            
        loc += 1
        
        # Decode the operands
        
        self.operands = []
        for t in self.operand_types:
            if t == 0:
                self.operands.append(self.memory.get_word_at(loc))
                loc += 2
            elif t == 1:
                self.operands.append(self.memory.memory[loc])
                loc += 1
            elif t == 2:
                # will be converted to var later
                self.operands.append(self.memory.memory[loc])
                loc += 1
                
        # A store instruction has an extra byte after the operands, which is a var
        
        if store_instr:
            self.store_var = self.memory.memory[loc]
            self.store_indirect = False
            loc += 1
                
        # A branch instruction has one or two bytes for the branch offset, or possibly
        # it means to return true or false.
        
        if branch_instr:
            loc = self._decode_branch(loc)
            
        # A call instruction has its first operand being the packed routine address to call,
        # (or an indirect reference to one) and the rest of the operands are the args. So
        # if the first operand is immediate, unpack it here. If not, we'll unpack it later
        
        if call_instr:
            if self.operand_types[0] != 2:
                self.routine_addr = self.memory.unpacked_addr(self.operands[0], True)
                self.operands[0] = self.routine_addr
            
            
        # Fix up flags and values for instructions which don't quite fit the neat
        # instruction format
        
        # Although INC_CHK and DEC_CHK store, they are not store instructions in that they
        # do not have a store_var after its operands.
        # Instead, operands[0] is the store_var. This means that if operands[0] is a var, the
        # store_var is the contents of that var; it would be an indirect store.
        
        if opcode_table == 0 and (op == 0x04 or op == 0x05):
            store_instr = True # fake it
            self.store_var = self.operands[0]
            self.store_indirect = self.operand_types[0] == 2
            self.operand_types[0] = 3 # not present
                
        # STORE likewise.
        
        elif opcode_table == 0 and op == 0x0D:
            store_instr = True # fake it
            self.store_var = self.operands[0]
            self.store_indirect = self.operand_types[0] == 2
            self.operand_types[0] = 3 # not present
                
        # INC and DEC, too.
        
        elif opcode_table == 2 and (op == 0x05 or op == 0x06):
            store_instr = True
            self.store_var = self.operands[0]
            self.store_indirect = self.operand_types[0] == 2
            self.operand_types[0] = 3 # not present
                
        # PULL also, but only for version 5 and below (i.e. when it's not a store_instr)
        
        elif opcode_table == 3 and op == 0x09 and not store_instr:
            store_instr = True
            if self.memory.header.version <= 5:
                self.store_var = self.operands[0]
                self.store_indirect = self.operand_types[0] == 2
                self.operand_types[0] = 3 # not present
            
        # JUMP is not a branch instruction, but its operand is its 2-byte signed offset,
        # and it goes to instr_loc + 1 + offset
        
        elif opcode_table == 2 and op == 0x0C:
            branch_instr = True
            self.branch_on_true = True
            self.condition = True
            branch_offset = word_to_signed(self.operands[0])
            self.target = self.instr_loc + 1 + branch_offset
            
        # PRINT and PRINT_RET have a zstr as their operand even though they
        # are not supposed to have an operand. So make it look like a PRINT_PADDR instruction.
        
        elif opcode_table == 1 and (op == 0x02 or op == 0x03):
            self.operands.append(loc)
            loc += self.memory.zchar_size(loc)
            
        trace = prev_trace
        if not execute or traceop:
            
            trace_instr = ""
            bytedata = "{:04X}: ".format(self.instr_loc)
            
            for i, b in enumerate(self.memory.memory[self.instr_loc : loc]):
                if i > 0 and (i%8)==0:
                    trace_instr += "{:31s}\n".format(bytedata)
                    bytedata = "{:4s}: ".format(" ")
                bytedata += "{:02X} ".format(b)
            trace_instr += "{:31s}".format(bytedata)
                
            traceops = ""
            
            # PRINT and PRINT_RET will just show their decoded zstr
            if opcode_table == 1 and (op == 0x02 or op == 0x03):
                zloc = self.operands[0]
                size = self.memory.zchar_size(zloc)
                traceops += self.memory.zcharlist_to_str(self.memory.zstrtozchars(self.memory.memory[zloc : zloc + size]))

            else:                
                opvals = []
                for t, o in zip(self.operand_types, self.operands):
                    if t == 0:
                        opvals.append("{:04X}".format(o))
                    elif t == 1:
                        opvals.append("{:02X}".format(o))
                    elif t == 2:
                        opvals.append(self._trace_variable(o))
                        
                if len(opvals) > 0:
                    for opval in opvals[:-1]:
                        traceops += opval
                        traceops += ", "
                    traceops += opvals[-1]
                
            if store_instr:
                traceops += " -> "
                if self.store_indirect:
                    traceops += "*"
                traceops += self._trace_variable(self.store_var)
                
            if branch_instr:
                traceops += " ?"
                if not self.branch_on_true:
                    traceops += "~"
                if hasattr(self, "return_value"):
                    traceops += "RFALSE" if self.return_value == 0 else " RTRUE"
                else:
                    traceops += "{:04X}".format(self.target)
                    
            trace_instr += "{:16s}{:30s}".format(op_name, traceops)
            if prev_trace != None:
                trace = prev_trace + "\n" + trace_instr
            else:
                trace = trace_instr
            
            debug_op = False
            if debug_op:
                print("  Opcode type: {:s}".format(Instruction_Interpreter.opcode_type_dict[self.opcode >> 6]))
                print("  Op: {:02X}".format(op))
                for i, operand in enumerate(self.operands):
                    print("  Operand {:d} type {:s} value {:04X}".format(i, 
                                                                       Instruction_Interpreter.operand_type_dict[self.operand_types[i]],
                                                                       operand))
                if call_instr:
                    print("  Call instruction")
                if store_instr:
                    print("  Store instruction")
                if branch_instr:
                    print("  Branch instruction")
                
        if not execute:
            return loc, trace
        
        
        # convert any var operands to their contents, which includes
        # an indirect store or indirect call if any
        
        if traceop:
            trace += " ;"

        for i, operand in enumerate(self.operands):
            if self.operand_types[i] == 2:
                contents = self.get_variable(operand)
                if call_instr and i == 0:
                    self.routine_addr = self.memory.unpacked_addr(contents, True)
                if traceop:
                    trace += " {:s} was {:04X} |".format(self._trace_variable(operand), contents)
                self.operands[i] = contents
                
        if store_instr and self.store_indirect:
            contents = self.get_variable(self.store_var)
            if traceop:
                trace += " [{:s} was {:04X}] |".format(self._trace_variable(self.store_var, contents))
            self.store_var = contents
            
        # For SREAD, trace before we wait for a line.
        
        if traceop and opcode_table == 3 and op == 0x04:
            print("{:s} (waiting for input)".format(trace))
            
        # Actually perform the instruction!
        
        self.jumptable[opcode_table][op]()
        
        # For SREAD, trace out the text and parse buffers
        
        if traceop and opcode_table == 3 and op == 0x04:
            text_buffer = self.operands[0]
            parse_buffer = self.operands[1]
            max_text = self.memory.memory[text_buffer]-1
            max_words = self.memory.memory[parse_buffer]
            words = self.memory.memory[parse_buffer+1]
            trace += " text_buffer (max {:d} chars): ".format(max_text)
            for z in self.memory.memory[text_buffer + 1 : text_buffer + 1 + max_text]:
                trace += "{:02X} ".format(z)
                if z == 0:
                    break
            trace += "| parse_buffer (max {:d} words, have {:d} words): ".format(max_words, words)
            for w in range(words):
                word_entry = self.memory.memory[parse_buffer + 2 + 4*w : parse_buffer + 6 + 4*w]
                trace += "({:04X}, {:02X}, {:02X}) ".format((word_entry[0]<<8) | word_entry[1], 
                                                            word_entry[2], word_entry[3])
            
        
        # do generic things for call, store, branch instructions
        
        if call_instr:
            if self.routine_addr == 0:
                if traceop:
                    trace += " call to 0000 always returns 0 |"
                self.store_value = 0
            else:
                self.args = (self.operands[1:], self.operand_types[1:])
                routine_locals, start_addr = self._get_routine_locals(self.routine_addr, loc, self.args)
                if traceop:
                    trace += " routine addr {:04X} | num_locals {:02X} | ".format(start_addr, len(routine_locals))
                    for l in routine_locals:
                        trace += "{:02X} ".format(l)
                routine_interpreter = Instruction_Interpreter(self.memory)
                self.store_value = routine_interpreter.interpret(start_addr, 
                                                                 trace = traceop, 
                                                                 local_vars = routine_locals,
                                                                 prev_trace = trace[:] if traceop else None)
                if traceop:
                    trace = "{:04X}: {:30s}{:16s}{:30s} ;".format(self.instr_loc, "", "", 
                                                                  " -> X" if not store_instr else " -> {:s}".format(self._trace_variable(self.store_var)))
                # return value will be ignored if not a store instruction
        
        if store_instr:
            self.set_variable(self.store_var, self.store_value)
            if traceop:
                trace += " {:s} now {:04X} |".format(self._trace_variable(self.store_var), self.store_value)
            
        if branch_instr:
            if traceop:
                trace += " branch_on_true {:d} | condition {:d} |".format(bool(self.branch_on_true), 
                                                                     bool(self.condition))
            if self.branch_on_true and self.condition or not self.branch_on_true and not self.condition:
                if hasattr(self, "return_value"):
                    if traceop:
                        trace += " Return value {:04X} |".format(self.return_value)
                    return -(self.return_value + 1), trace
                if traceop:
                    trace += " Branch goes to target {:04X} |".format(self.target)
                return self.target, trace
            if hasattr(self, "return_value"):
                del self.return_value
            if traceop:
                trace += " Branch not taken |"
                
        if hasattr(self, "return_value"):
            if traceop:
                trace += " Return value {:04X} |".format(self.return_value)
            return -(self.return_value + 1), trace
        
        return loc, trace
    
    # Interpret instructions until one of them returns a negative "location". A negative
    # location actually encodes a return value as -(value + 1), so the return value is
    # -location - 1 (the extra 1 lets a return value of zero still be encoded as negative).

    def interpret(self, start_loc, trace = False, trace_file = None, local_vars=None, prev_trace=None):
        '''Execute instructions starting at start_loc until one of them returns.

        interpret_instr yields either the address of the next instruction or,
        when the routine returns, a negative value -(return_value + 1). The
        extra 1 lets a return value of 0 still be encoded as a negative number.

        start_loc: byte address of the first instruction to execute.
        trace: if true, print each instruction's trace text to trace_file.
        trace_file: file object receiving trace output (None means sys.stdout,
            as with print()).
        local_vars: list holding this routine's local variables. When omitted,
            a fresh empty list is created so separate calls never share storage.
        prev_trace: trace text to prepend to the first traced instruction
            (interpret_instr passes the caller's trace here when tracing into
            a called routine).

        Returns the routine's return value.
        '''
        self.stack = []
        # A fresh list per call: a shared default list would leak local
        # variable writes from one routine invocation into the next.
        self.local_vars = [] if local_vars is None else local_vars
        loc = start_loc
        while loc > 0:
            loc, traceval = self.interpret_instr(loc, execute = True, traceop = trace, prev_trace = prev_trace)
            if trace:
                print(traceval, file=trace_file)
                # The pending caller trace has been emitted once; don't repeat it.
                prev_trace = None
        # Decode the -(value + 1) encoding back into the actual return value.
        return -loc-1


def main():
    '''Command-line entry point: load a Z-code story file and run it.

    Parses the command line, builds the Memory image from the story file,
    prints the header and memory size, then interprets from the header's
    initial PC (optionally tracing each instruction with -t/--trace).
    '''
    parser = argparse.ArgumentParser(description='Z-Machine interpreter for Z-coded story files.')
    parser.add_argument('-V', '--version', action='version', version='%(prog)s 1.0 (2012-05-31)')
    parser.add_argument('-t', '--trace', action='store_true')
    parser.add_argument('inputFile', type=argparse.FileType('rb'), nargs=1,
                        help='The story file to run')

    args = parser.parse_args()

    # argparse.FileType opens the file but never closes it; scope it with `with`
    # around the whole run so the handle is released once interpretation ends.
    with args.inputFile[0] as zfile:
        # Measure the file that was actually opened rather than re-resolving
        # a path through the buffered reader's underlying raw object.
        zfilesize = os.fstat(zfile.fileno()).st_size
        memory = Memory(zfile = zfile, zfilesize = zfilesize)
        print(memory.header)
        print(len(memory.memory))
        interpreter = Instruction_Interpreter(memory)
        interpreter.interpret(memory.header.initial_PC, trace=args.trace)

if __name__ == "__main__":
    # Start the interpreter only when run as a script, never on import.
    main()